import json
import numpy as np
from ex import set_configs
from sklearn.metrics import confusion_matrix
import keras.backend as K
import os
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from utils.Models import VGG_16, choose_model
from DataUtils.Load_util import load_local_train_val_DEBUG, load_local_train_val, select_drivers
from skimage import io
import cv2


def show_curve():
    """Plot training-history curves saved by the experiment runner.

    Reads the JSON history file (hard-coded path) and draws three
    side-by-side panels -- accuracy, loss and Kaggle log-loss -- each
    showing the training curve and the validation curve together.
    """
    json_path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDistractedDriverDetection/experiments/80x60/baseline/' \
                'M-VGG_16_L-categorical_crossentropy_E-50_CF-softmax.json'
    # Context manager so the handle is closed even if json.load raises
    # (the original left the file open for the process lifetime).
    with open(json_path, 'r') as f:
        result = json.load(f)
    history = result['training']
    # Each array is 2 x epochs: row 0 = training metric, row 1 = validation.
    acc_curve = np.asarray([history['acc'], history['val_acc']])
    loss_curve = np.asarray([history['loss'], history['val_loss']])
    log_loss_curve = np.asarray([history['log_loss'], history['val_log_loss']])
    # Transpose so matplotlib plots one line per row (epochs on the x-axis).
    for panel, curve in enumerate([acc_curve, loss_curve, log_loss_curve], 1):
        plt.subplot(1, 3, panel)
        plt.plot(curve.transpose())
    plt.show()


def show_confusion_matrix(config):
    """Plot the confusion matrix of a trained model on the validation split.

    Loads weights from a hard-coded checkpoint, predicts the classes of the
    validation drivers (the last driver group returned by
    load_local_train_val) and draws the 10x10 confusion matrix with
    human-readable class labels on both axes.

    :param config: experiment config dict consumed by choose_model /
        load_local_train_val; 'crop_right' is forced to 'True' here to
        match the checkpoint being loaded.
    """
    # BUG FIX: the implicit string concatenation was missing the '/'
    # between 'baseline' and the file name, yielding '...baselineM-VGG...'
    # so load_weights could never find the checkpoint.
    weight_path = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/experiments/224x224/baseline/' \
                  'M-VGG_16_L-categorical_crossentropy_E-3_CF-softmax_BS-128_CR-True.h5'
    config['crop_right'] = 'True'  # checkpoint was trained with right-crop on
    model = choose_model(config)
    model.load_weights(weight_path)

    image_narrays, lable_narrays, drivers_id, unique_drivers = load_local_train_val(config=config)
    # Hold out the last driver group as the validation set.
    unique_list_val = unique_drivers[-1]
    val_X, val_Y, index = select_drivers(image_narrays, lable_narrays, drivers_id, unique_list_val)
    pred = model.predict_classes(val_X)
    # Class names for c0..c9 of the Kaggle Distracted Driver dataset.
    labels = [
        'safe driving',
        'texting - right',
        'talking on the phone - right',
        'texting - left',
        'talking on the phone - left',
        'operating the radio',
        'drinking',
        'reaching behind',
        'hair and makeup',
        'talking to passenger']
    plt.matshow(confusion_matrix(val_Y, pred), cmap='Reds', interpolation='none')
    plt.yticks(np.arange(10), labels)
    plt.xticks(np.arange(10), labels, rotation=90)


def show_heat_map():
    config, _ = set_configs()
    config['to_gray'] = 'False'
    config['weights_path'] = None

    with K.tf.device('/gpu:1'):
        convnet = VGG_16(config, heat_map=False)
        convnet.compile('adam', 'categorical_crossentropy')
        convnet.load_weights(
            '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/experiments/224x224/baseline/'
            'M-VGG_16_E-3_BS-128_CR-False_CC-True_Aug-False_ConTi-False_L-categorical_crossentropy_CF-softmax.h5')
    convnet_heatmap = VGG_16(config, heat_map=True)
    for layer in convnet_heatmap.layers:
        if layer.name.startswith("conv"):
            orig_layer = convnet.get_layer(layer.name)
            layer.set_weights(orig_layer.get_weights())
        elif layer.name.startswith("dense"):
            orig_layer = convnet.get_layer(layer.name)
            W, b = orig_layer.get_weights()
            n_filter, previous_filter, ax1, ax2 = layer.get_weights()[0].shape
            new_W = W.reshape((previous_filter, ax1, ax2, n_filter))
            new_W = new_W.transpose((3, 0, 1, 2))
            new_W = new_W[:, :, ::-1, ::-1]
            layer.set_weights([new_W, b])
    with K.tf.device('/gpu:1'):
        convnet_heatmap.compile(optimizer='adam', loss='mse')

    img_fold = '/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/train/c0'
    for img in os.listdir(img_fold):
        img_path = img_fold + '/' + img
        im = io.imread(img_path)
        plt.subplot(1, 2, 1)
        io.imshow(im)
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
        im /= 255.
        im = im.transpose((2, 0, 1))
        im = np.expand_dims(im, 0)
        out = convnet_heatmap.predict(im)
        print out.shape
        heatmap = out[0, 0, :, :]
        print heatmap
        plt.subplot(1, 2, 2)
        io.imshow(heatmap * 255)
        io.show()

def t_SNE_visualization():
    # X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
    X = np.load('/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/npy/VGG_FC7_X.npy')
    y = np.load('/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/npy/VGG_FC7_y.npy')
    # X, y = sklearn.utils.shuffle(X, y)
    print 'data loaded'
    model = TSNE(n_components=2, random_state=0, n_iter=1000, verbose=1)
    np.set_printoptions(suppress=True)
    X_embeded = model.fit_transform(X)
    print 't-SNE finished'
    np.save('metadata/X_embedded.npy', X_embeded)
    print X_embeded.shape
    fig = plt.figure()
    plt.scatter(X_embeded[:, 0], X_embeded[:, 1], c=y)
    fig.savefig('metadata/X_embedded')

if __name__ == '__main__':
    # Script entry point: only the t-SNE visualization runs by default;
    # the other visualizations are invoked manually.
    t_SNE_visualization()
